import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional

import torch
from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import BlockHash

from vllm_ascend.distributed.kvpool.backend.backend import Backend

# isort: off
from vllm_ascend.distributed.kvpool.config_data import (ChunkedTokenDatabase,
                                                        LasyerMultiBlockReqMeta
                                                        )
# isort: on


class KVTransferThread(threading.Thread):
    """Daemon worker that drains ``request_queue`` and dispatches each item
    to ``_handle_request``.

    ``run`` only calls ``request_queue.task_done()`` itself for ``None``
    items; for every other item the subclass's ``_handle_request`` is
    responsible for calling it exactly once.
    """

    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
                 tp_rank: int, dcp_size: int, ready_event: threading.Event,
                 name: str):
        super().__init__(daemon=True, name=name)
        self.m_store = m_store
        # Set by ``run`` right after ``m_store.set_device()`` succeeds.
        self.ready_event = ready_event
        self.tp_rank = tp_rank
        self.dcp_size = dcp_size
        self.token_database = token_database
        # Guards ``finished_requests``: it is added to from this worker thread
        # and drained via ``get_and_clear_finished_requests`` (presumably from
        # another thread).
        self.done_task_lock = threading.Lock()
        self.request_queue: queue.Queue[Any] = queue.Queue()
        # TODO(jianzs): make this configurable
        self.executor = ThreadPoolExecutor(max_workers=32)
        self.finished_requests: set[str] = set()

    def add_request(
        self,
        req_id: str,
        token_len: int,
        block_ids: list[int],
        block_hashes: list[BlockHash],
        mask_num: int = 0,
        is_last_chunk: Optional[bool] = None,
    ) -> None:
        """Enqueue a transfer request for the worker thread.

        Args:
            req_id: Identifier later reported by
                ``get_and_clear_finished_requests``.
            token_len: Passed to ``token_database.process_tokens``.
            block_ids: Passed to ``token_database.prepare_value``.
            block_hashes: Passed to ``token_database.process_tokens`` to
                derive the storage keys.
            mask_num: Passed to ``token_database.process_tokens``; presumably
                the number of leading tokens to skip — verify there.
            is_last_chunk: Whether this chunk completes the request (only
                consulted by subclasses that track chunking).
        """
        self.request_queue.put({
            "req_id": req_id,
            "token_len": token_len,
            "block_ids": block_ids,
            "block_hashes": block_hashes,
            "mask_num": mask_num,
            "is_last_chunk": is_last_chunk,
        })

    def get_and_clear_finished_requests(self) -> set[str]:
        """
        Get and clear the requests that have been completed.
        Returns:
            A set of request IDs that have been completed.
        """
        with self.done_task_lock:
            finished_requests = self.finished_requests.copy()
            self.finished_requests.clear()
        return finished_requests

    def set_finished_request(self, req_id):
        """Record ``req_id`` as completed (thread-safe via ``done_task_lock``)."""
        with self.done_task_lock:
            self.finished_requests.add(req_id)

    def run(self):
        """Bind the backend device on this thread, signal readiness, then
        process queued requests forever."""
        self.m_store.set_device()
        self.ready_event.set()
        while True:
            try:
                request_data = self.request_queue.get()
                if request_data is None:
                    logger.warning("Received a None request!")
                    self.request_queue.task_done()
                    continue
                self._handle_request(request_data)
            except Exception as e:
                # Keep the worker alive on a per-request failure, but log the
                # full traceback so the root cause is not lost.
                logger.exception("Error in KVCacheTransferThread: %s", e)

    def _handle_request(self, req_meta: dict[str, Any]):
        """Process one dequeued request; overridden by subclasses.

        Implementations must call ``self.request_queue.task_done()`` exactly
        once per request, because ``run`` does not do it for them.
        """
        pass


class KVCacheStoreSendingThread(KVTransferThread):
    """Worker thread that writes (``put``s) KV-cache blocks to the store."""

    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
                 tp_rank: int, dcp_size: int, put_step: int,
                 ready_event: threading.Event):
        super().__init__(m_store,
                         token_database,
                         tp_rank,
                         dcp_size,
                         ready_event,
                         name="KVCacheSendingThread")
        # Stride used to shard keys across ranks when ``dcp_size <= 1``.
        self.put_step = put_step

    def _handle_request(self, req_meta: dict[str, Any]):
        """Put the blocks described by ``req_meta`` into the store.

        When ``dcp_size > 1`` this rank puts every key. Otherwise it puts only
        every ``put_step``-th key, starting at ``tp_rank % put_step``
        (presumably so ranks holding identical KV do not write duplicates).
        The request is marked finished only after the last chunk has been put
        successfully. ``task_done`` is always called, even if building the
        keys or the put raises, so the queue's unfinished-task count stays
        balanced.
        """
        try:
            token_len = req_meta["token_len"]
            mask_num = req_meta["mask_num"]
            block_ids = req_meta["block_ids"]
            block_hashes = req_meta["block_hashes"]
            req_id = req_meta["req_id"]
            is_last_chunk = req_meta["is_last_chunk"]

            addr_list = []
            size_list = []
            key_list = []
            for start, end, key in self.token_database.process_tokens(
                    token_len, block_hashes, mask_num):
                addr, size, _ = self.token_database.prepare_value(
                    start, end, block_ids)
                key_list.append(key.to_string())
                addr_list.append(addr)
                size_list.append(size)

            if self.dcp_size > 1:
                # Skip empty batches, consistent with the sharded branch below.
                if key_list:
                    self.m_store.put(key_list, addr_list, size_list)
            else:
                offset = self.tp_rank % self.put_step
                key_list_tp = key_list[offset::self.put_step]
                addr_list_tp = addr_list[offset::self.put_step]
                size_list_tp = size_list[offset::self.put_step]
                if key_list_tp:
                    self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)

            if is_last_chunk:
                self.set_finished_request(req_id)
        finally:
            self.request_queue.task_done()


class KVCacheStoreRecvingThread(KVTransferThread):
    """Worker thread that loads (``get``s) KV-cache blocks from the store."""

    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
                 tp_rank: int, dcp_size: int, ready_event: threading.Event):
        super().__init__(m_store,
                         token_database,
                         tp_rank,
                         dcp_size,
                         ready_event,
                         name="KVCacheStoreRecvingThread")

    def _handle_request(self, req_meta: dict[str, Any]):
        """Fetch the blocks described by ``req_meta`` and mark it finished.

        The key/addr/size lists are rotated left by ``tp_rank`` before the
        ``get`` (presumably so different ranks start on different keys and
        spread load). If there are no keys, the ``get`` is skipped and the
        request is marked finished immediately. Previously this case raised
        ``ZeroDivisionError`` on ``% len(...)`` and the request never
        finished. ``task_done`` is always called, even on failure.
        """
        try:
            token_len = req_meta["token_len"]
            mask_num = req_meta["mask_num"]
            block_ids = req_meta["block_ids"]
            req_id = req_meta["req_id"]
            block_hashes = req_meta["block_hashes"]

            addr_list = []
            size_list = []
            key_list = []
            for start, end, key in self.token_database.process_tokens(
                    token_len, block_hashes, mask_num):
                addr, size, _ = self.token_database.prepare_value(
                    start, end, block_ids)
                key_list.append(key.to_string())
                addr_list.append(addr)
                size_list.append(size)

            if key_list:
                offset = self.tp_rank % len(key_list)
                self.m_store.get(key_list[offset:] + key_list[:offset],
                                 addr_list[offset:] + addr_list[:offset],
                                 size_list[offset:] + size_list[:offset])
            self.set_finished_request(req_id)
        finally:
            self.request_queue.task_done()


class KVCacheStoreLayerSendingThread(KVTransferThread):
    """Worker thread that puts KV-cache blocks to the store one layer at a
    time."""

    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
                 tp_rank: int, dcp_size: int, put_step: int,
                 ready_event: threading.Event, num_layers: int):
        super().__init__(m_store,
                         token_database,
                         tp_rank,
                         dcp_size,
                         ready_event,
                         name="KVCacheStoreLayerSendingThread")
        # A request is reported finished once this layer of its last chunk
        # has been put.
        self.final_layer_id = num_layers - 1
        # Stride used to shard keys across ranks when ``dcp_size <= 1``.
        self.put_step = put_step

    def add_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta) -> None:
        """Enqueue a per-layer multi-block request."""
        self.request_queue.put(req_meta)

    def _handle_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta):
        """Put one layer's blocks for ``req_meta`` into the store.

        The sharding rule matches the non-layered sender: every key is put
        when ``dcp_size > 1``, otherwise every ``put_step``-th key starting at
        ``tp_rank % put_step``. The request is marked finished only for the
        final layer of the last chunk. ``task_done`` is always called, even
        on failure.
        """
        try:
            addr_list = []
            size_list = []
            key_list = []
            for index, key in enumerate(req_meta.keys):
                addr, size = self.token_database.prepare_value_layer(
                    req_meta.starts[index], req_meta.ends[index],
                    req_meta.block_ids, req_meta.layer_id)
                key_list.append(key.to_string())
                addr_list.append(addr)
                size_list.append(size)

            if self.dcp_size > 1:
                # Skip empty batches, consistent with the sharded branch below.
                if key_list:
                    self.m_store.put(key_list, addr_list, size_list)
            else:
                offset = self.tp_rank % self.put_step
                key_list_tp = key_list[offset::self.put_step]
                addr_list_tp = addr_list[offset::self.put_step]
                size_list_tp = size_list[offset::self.put_step]
                if key_list_tp:
                    self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)

            if (req_meta.layer_id == self.final_layer_id
                    and req_meta.is_last_chunk):
                self.set_finished_request(req_meta.req_id)
        finally:
            self.request_queue.task_done()


class KVCacheStoreLayerRecvingThread(KVTransferThread):
    """Worker thread that gets KV-cache blocks from the store one layer at a
    time, signalling ``get_event`` after each layer."""

    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
                 tp_rank: int, dcp_size: int, ready_event: threading.Event,
                 get_event: threading.Event):
        super().__init__(m_store,
                         token_database,
                         tp_rank,
                         dcp_size,
                         ready_event,
                         name="KVCacheStoreLayerRecvingThread")
        # Set after each successfully handled layer request.
        self.get_event = get_event

    def add_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta) -> None:
        """Enqueue a per-layer multi-block request."""
        self.request_queue.put(req_meta)

    def _handle_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta):
        """Fetch one layer's blocks for ``req_meta``, then set ``get_event``.

        The key/addr/size lists are rotated left by ``tp_rank`` before the
        ``get``, as in the non-layered receiver. With no keys the ``get`` is
        skipped and the event is still set. Previously this case raised
        ``ZeroDivisionError`` and the event was never set. ``task_done`` is
        always called, even on failure.
        """
        try:
            addr_list = []
            size_list = []
            key_list = []
            for index, key in enumerate(req_meta.keys):
                addr, size = self.token_database.prepare_value_layer(
                    req_meta.starts[index], req_meta.ends[index],
                    req_meta.block_ids, req_meta.layer_id)
                key_list.append(key.to_string())
                addr_list.append(addr)
                size_list.append(size)

            if key_list:
                offset = self.tp_rank % len(key_list)
                self.m_store.get(key_list[offset:] + key_list[:offset],
                                 addr_list[offset:] + addr_list[:offset],
                                 size_list[offset:] + size_list[:offset])
        finally:
            self.request_queue.task_done()
        # NOTE(review): if the get raises, get_event is not set (unchanged
        # from the original), so a waiter may block; consider an explicit
        # failure signal.
        self.get_event.set()
